{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import numpy as np\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import *\n",
    "from keras.preprocessing.sequence import pad_sequences "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "处理一下tensorflow的问题："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import keras.backend as K\n",
    "\n",
    "if K.backend() == \"tensorflow\":\n",
    "    config = K.tf.ConfigProto()\n",
    "    config.gpu_options.allow_growth = True\n",
    "    session = K.tf.Session(config=config)\n",
    "    K.set_session(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读数据："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "with open(\"data/all.txt\") as f:\n",
    "    all_data = [line.strip().split(\";\") for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "data_x_team_id_raw = all_data[0::9]\n",
    "data_y_raw = all_data[1::9]\n",
    "data_x_team_abbr_raw = all_data[2::9]\n",
    "data_x_home_min_raw = all_data[3::9]\n",
    "data_x_home_id_raw = all_data[4::9]\n",
    "data_x_home_name_raw = all_data[5::9]\n",
    "data_x_visitor_min_raw = all_data[6::9]\n",
    "data_x_visitor_id_raw = all_data[7::9]\n",
    "data_x_visitor_name_raw = all_data[8::9]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def flatten(x):\n",
    "    for seq in x:\n",
    "        for s in seq:\n",
    "            yield s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "把球队名称和ID对应："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "id2team = dict(zip(flatten(data_x_team_id_raw), flatten(data_x_team_abbr_raw)))\n",
    "id2player = dict(zip(flatten(data_x_home_id_raw + data_x_visitor_id_raw), \n",
    "                     flatten(data_x_home_name_raw + data_x_visitor_name_raw)))\n",
    "team2id = dict(zip(flatten(data_x_team_abbr_raw), flatten(data_x_team_id_raw)))\n",
    "player2id = dict(zip(flatten(data_x_home_name_raw + data_x_visitor_name_raw),\n",
    "                     flatten(data_x_home_id_raw + data_x_visitor_id_raw)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_teams 30\n"
     ]
    }
   ],
   "source": [
    "print \"total_teams\", len(id2team)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_players 898\n"
     ]
    }
   ],
   "source": [
    "print \"total_players\", len(id2player)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将球队id和队员id序列化方便Embbedding："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "tid2index = {tid: idx for idx, tid in enumerate(id2team)}\n",
    "index2tid = {idx: tid for idx, tid in enumerate(id2team)}\n",
    "\n",
    "pid2index = {pid: idx+1 for idx, pid in enumerate(id2player)}\n",
    "index2pid = {idx+1: pid for idx, pid in enumerate(id2player)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将时间转化为秒："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def str2time(ms):\n",
    "    m, s = ms.split(\":\")\n",
    "    return int(m) * 60 + int(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取数据："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "data_x_team_id = np.array(map(lambda x: [tid2index[tid] for tid in x], data_x_team_id_raw))\n",
    "data_x_home_id   = pad_sequences(map(lambda x: [pid2index[pid] for pid in x], data_x_home_id_raw), padding=\"post\", maxlen=13)\n",
    "data_x_vistor_id = pad_sequences(map(lambda x: [pid2index[pid] for pid in x], data_x_visitor_id_raw), padding=\"post\", maxlen=13)\n",
    "data_x_home_min    = pad_sequences(map(lambda x: [str2time(ms) for ms in x], data_x_home_min_raw), padding=\"post\", maxlen=13)\n",
    "data_x_visitor_min = pad_sequences(map(lambda x: [str2time(ms) for ms in x], data_x_visitor_min_raw), padding=\"post\", maxlen=13)\n",
    "\n",
    "data_x_home_min = 5 * data_x_home_min.astype(K.floatx()) / data_x_home_min.sum(axis=-1)[:,None]\n",
    "data_x_visitor_min = 5 * data_x_visitor_min.astype(K.floatx()) / data_x_visitor_min.sum(axis=-1)[:,None]\n",
    "\n",
    "data_y = np.array(data_y_raw, dtype=int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造比分预测模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "x_t = Input((2,), name=\"team_id\") \n",
    "\n",
    "x_h_id = Input((13,), name=\"home_player_id\")\n",
    "x_h_min = Input((13,1), name=\"home_player_time\")\n",
    "\n",
    "x_v_id = Input((13,), name=\"visitor_player_id\")\n",
    "x_v_min = Input((13,1), name=\"visitor_player_time\")\n",
    "\n",
    "emb_dim = 256\n",
    "\n",
    "team_emb = Sequential(name=\"team_emb\")\n",
    "team_emb.add(Embedding(input_dim=30, output_dim=emb_dim, input_length=2))\n",
    "team_emb.add(Flatten())\n",
    "\n",
    "player_emb = Sequential(name=\"player_emb\")\n",
    "player_emb.add(Embedding(input_dim=len(id2player)+1, output_dim=emb_dim, input_length=13))\n",
    "\n",
    "feat_t = team_emb(x_t)\n",
    "\n",
    "feat_h_id = player_emb(x_h_id)\n",
    "feat_h = dot([x_h_min, feat_h_id], axes=1, name=\"home_player_sum\")\n",
    "feat_h = Reshape((emb_dim,), name=\"home_feat\")(feat_h)\n",
    "\n",
    "feat_v_id = player_emb(x_v_id)\n",
    "feat_v = dot([x_v_min, feat_v_id], axes=1, name=\"visitor_player_sum\")\n",
    "feat_v = Reshape((emb_dim,), name=\"visitor_feat\")(feat_v)\n",
    "\n",
    "feat = concatenate([feat_t, feat_h, feat_v], name=\"all_feat\")\n",
    "\n",
    "hid = Dense(256, activation=\"relu\", name=\"hidden_1\")(feat)\n",
    "hid = Dropout(0.2, name=\"dropout_1\")(hid)\n",
    "hid = Dense(128, activation=\"relu\", name=\"hidden_2\")(hid)\n",
    "hid = Dropout(0.2, name=\"dropout_2\")(hid)\n",
    "score = Dense(2, activation=\"relu\", name=\"score\")(hid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "model = Model(inputs=[x_t, x_h_id, x_h_min, x_v_id, x_v_min], outputs=score)\n",
    "\n",
    "\n",
    "def win(y_true, y_pred):\n",
    "    return K.mean(K.equal(K.argmax(y_pred, axis=-1), K.argmax(y_true, axis=-1)))\n",
    "\n",
    "model.compile(optimizer=\"adam\", loss=\"mae\", metrics=[win])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(1105)\n",
    "\n",
    "idx = np.arange(len(data_y))\n",
    "np.random.shuffle(idx)\n",
    "\n",
    "train_idx = idx[600:]\n",
    "valid_idx = idx[:600]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6632"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "hist = model.fit([data_x_team_id[train_idx], \n",
    "                  data_x_home_id[train_idx], data_x_home_min[train_idx, :, None], \n",
    "                  data_x_vistor_id[train_idx], data_x_visitor_min[train_idx, :, None]],\n",
    "                 data_y[train_idx],\n",
    "                 validation_data = ([data_x_team_id[valid_idx], \n",
    "                                     data_x_home_id[valid_idx], data_x_home_min[valid_idx, :, None], \n",
    "                                     data_x_vistor_id[valid_idx], data_x_visitor_min[valid_idx, :, None]],\n",
    "                                    data_y[valid_idx]),\n",
    "                 verbose=0,\n",
    "                 epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[10.165847600301106,\n",
       " 9.9543780390421546,\n",
       " 9.8001865132649737,\n",
       " 10.289466730753581,\n",
       " 9.228977038065592,\n",
       " 11.430291798909506,\n",
       " 9.3323127492268885,\n",
       " 9.1152853902180997,\n",
       " 9.1859752019246415,\n",
       " 9.4495480346679681,\n",
       " 9.5419914118448901,\n",
       " 9.3206252034505201,\n",
       " 9.0971596018473306,\n",
       " 9.0549177932739262,\n",
       " 9.2275110244750973,\n",
       " 9.3106695048014316,\n",
       " 9.1169172541300458,\n",
       " 9.3980016962687181,\n",
       " 9.1642890930175778,\n",
       " 9.702613741556803,\n",
       " 9.2596215184529616,\n",
       " 9.3760626475016284,\n",
       " 8.9202607472737636,\n",
       " 9.2350659434000644,\n",
       " 9.2720768737792962,\n",
       " 10.19280756632487,\n",
       " 9.3464718373616531,\n",
       " 9.4991256459554041,\n",
       " 9.2756049474080395,\n",
       " 9.1515356826782224]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hist.history[\"val_loss\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "预测测试数据正确率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.656666666667\n",
      "0.563333333333\n"
     ]
    }
   ],
   "source": [
    "data_t = model.predict([data_x_team_id[valid_idx], \n",
    "                        data_x_home_id[valid_idx], data_x_home_min[valid_idx , :, None], \n",
    "                        data_x_vistor_id[valid_idx], data_x_visitor_min[valid_idx , :, None]])\n",
    "\n",
    "print np.mean(data_t.argmax(axis=-1) == data_y[valid_idx].argmax(axis=-1))\n",
    "print np.mean(0 == data_y[valid_idx].argmax(axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试集上的预测比分："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 118.31840515,  106.11000824],\n",
       "       [  99.41278076,   93.19446564],\n",
       "       [ 100.38309479,   99.36334991],\n",
       "       ..., \n",
       "       [  99.76564026,   93.42744446],\n",
       "       [  91.10274506,   97.74150848],\n",
       "       [ 103.73945618,   95.22878265]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "真实比分："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[116,  92],\n",
       "       [107, 112],\n",
       "       [ 94, 100],\n",
       "       ..., \n",
       "       [105, 109],\n",
       "       [112, 100],\n",
       "       [ 95,  94]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_y[valid_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "data_x_team_abbr = np.array(data_x_team_abbr_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['SAC', 'LAL'],\n",
       "       ['LAL', 'HOU'],\n",
       "       ['MIN', 'NYK'],\n",
       "       ..., \n",
       "       ['WAS', 'MIL'],\n",
       "       ['CHA', 'GSW'],\n",
       "       ['SAC', 'WAS']], \n",
       "      dtype='|S3')"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_x_team_abbr[valid_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练集上的比分和队伍："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "data_p = model.predict([data_x_team_id[train_idx], \n",
    "                        data_x_home_id[train_idx], data_x_home_min[train_idx , :, None], \n",
    "                        data_x_vistor_id[train_idx], data_x_visitor_min[train_idx , :, None]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 111.05187225,   93.65827179],\n",
       "       [ 103.43552399,  105.09658813],\n",
       "       [ 108.43019104,  107.60032654],\n",
       "       ..., \n",
       "       [  85.79730988,   81.65973663],\n",
       "       [ 107.5667038 ,  106.44389343],\n",
       "       [  99.8336792 ,   98.61862183]], dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['HOU', 'PHI'],\n",
       "       ['DEN', 'ATL'],\n",
       "       ['PHX', 'MIN'],\n",
       "       ..., \n",
       "       ['DET', 'IND'],\n",
       "       ['LAL', 'BKN'],\n",
       "       ['MIN', 'UTA']], \n",
       "      dtype='|S3')"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_x_team_abbr[train_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "真实比分："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[120,  98],\n",
       "       [105, 119],\n",
       "       [107, 104],\n",
       "       ..., \n",
       "       [ 77,  88],\n",
       "       [105, 114],\n",
       "       [ 92,  94]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_y[train_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模型结构："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg height=\"702pt\" viewBox=\"0.00 0.00 594.00 702.00\" width=\"594pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 698)\">\n",
       "<title>G</title>\n",
       "<polygon fill=\"white\" points=\"-4,4 -4,-698 590,-698 590,4 -4,4\" stroke=\"none\"/>\n",
       "<!-- 140572603596240 -->\n",
       "<g class=\"node\" id=\"node1\"><title>140572603596240</title>\n",
       "<polygon fill=\"none\" points=\"96,-657.5 96,-693.5 269,-693.5 269,-657.5 96,-657.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"182.5\" y=\"-671.8\">home_player_id: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140572603639952 -->\n",
       "<g class=\"node\" id=\"node4\"><title>140572603639952</title>\n",
       "<polygon fill=\"none\" points=\"205.5,-584.5 205.5,-620.5 351.5,-620.5 351.5,-584.5 205.5,-584.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"278.5\" y=\"-598.8\">player_emb: Sequential</text>\n",
       "</g>\n",
       "<!-- 140572603596240&#45;&gt;140572603639952 -->\n",
       "<g class=\"edge\" id=\"edge1\"><title>140572603596240-&gt;140572603639952</title>\n",
       "<path d=\"M205.739,-657.313C218.238,-648.069 233.822,-636.543 247.358,-626.532\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"249.516,-629.289 255.475,-620.529 245.354,-623.661 249.516,-629.289\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603596368 -->\n",
       "<g class=\"node\" id=\"node2\"><title>140572603596368</title>\n",
       "<polygon fill=\"none\" points=\"287.5,-657.5 287.5,-693.5 463.5,-693.5 463.5,-657.5 287.5,-657.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"375.5\" y=\"-671.8\">visitor_player_id: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140572603596368&#45;&gt;140572603639952 -->\n",
       "<g class=\"edge\" id=\"edge2\"><title>140572603596368-&gt;140572603639952</title>\n",
       "<path d=\"M352.019,-657.313C339.389,-648.069 323.643,-636.543 309.966,-626.532\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"311.901,-623.611 301.765,-620.529 307.767,-629.26 311.901,-623.611\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603596560 -->\n",
       "<g class=\"node\" id=\"node3\"><title>140572603596560</title>\n",
       "<polygon fill=\"none\" points=\"0,-584.5 0,-620.5 187,-620.5 187,-584.5 0,-584.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.5\" y=\"-598.8\">home_player_time: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140572603639824 -->\n",
       "<g class=\"node\" id=\"node7\"><title>140572603639824</title>\n",
       "<polygon fill=\"none\" points=\"122.5,-511.5 122.5,-547.5 268.5,-547.5 268.5,-511.5 122.5,-511.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"195.5\" y=\"-525.8\">home_player_sum: Dot</text>\n",
       "</g>\n",
       "<!-- 140572603596560&#45;&gt;140572603639824 -->\n",
       "<g class=\"edge\" id=\"edge3\"><title>140572603596560-&gt;140572603639824</title>\n",
       "<path d=\"M117.932,-584.494C131.332,-575.166 148.129,-563.474 162.662,-553.358\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"164.706,-556.199 170.914,-547.614 160.707,-550.454 164.706,-556.199\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603639952&#45;&gt;140572603639824 -->\n",
       "<g class=\"edge\" id=\"edge4\"><title>140572603639952-&gt;140572603639824</title>\n",
       "<path d=\"M258.408,-584.313C247.805,-575.243 234.635,-563.977 223.089,-554.1\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"225.281,-551.37 215.407,-547.529 220.731,-556.689 225.281,-551.37\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140575373044240 -->\n",
       "<g class=\"node\" id=\"node8\"><title>140575373044240</title>\n",
       "<polygon fill=\"none\" points=\"287,-511.5 287,-547.5 436,-547.5 436,-511.5 287,-511.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-525.8\">visitor_player_sum: Dot</text>\n",
       "</g>\n",
       "<!-- 140572603639952&#45;&gt;140575373044240 -->\n",
       "<g class=\"edge\" id=\"edge6\"><title>140572603639952-&gt;140575373044240</title>\n",
       "<path d=\"M298.592,-584.313C309.195,-575.243 322.365,-563.977 333.911,-554.1\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"336.269,-556.689 341.593,-547.529 331.719,-551.37 336.269,-556.689\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603638352 -->\n",
       "<g class=\"node\" id=\"node5\"><title>140572603638352</title>\n",
       "<polygon fill=\"none\" points=\"370,-584.5 370,-620.5 561,-620.5 561,-584.5 370,-584.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"465.5\" y=\"-598.8\">visitor_player_time: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140572603638352&#45;&gt;140575373044240 -->\n",
       "<g class=\"edge\" id=\"edge5\"><title>140572603638352-&gt;140575373044240</title>\n",
       "<path d=\"M440.589,-584.494C426.926,-575.166 409.8,-563.474 394.982,-553.358\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"396.8,-550.361 386.568,-547.614 392.853,-556.143 396.8,-550.361\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603596176 -->\n",
       "<g class=\"node\" id=\"node6\"><title>140572603596176</title>\n",
       "<polygon fill=\"none\" points=\"454.5,-511.5 454.5,-547.5 582.5,-547.5 582.5,-511.5 454.5,-511.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"518.5\" y=\"-525.8\">team_id: InputLayer</text>\n",
       "</g>\n",
       "<!-- 140572603639632 -->\n",
       "<g class=\"node\" id=\"node9\"><title>140572603639632</title>\n",
       "<polygon fill=\"none\" points=\"447,-438.5 447,-474.5 586,-474.5 586,-438.5 447,-438.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"516.5\" y=\"-452.8\">team_emb: Sequential</text>\n",
       "</g>\n",
       "<!-- 140572603596176&#45;&gt;140572603639632 -->\n",
       "<g class=\"edge\" id=\"edge7\"><title>140572603596176-&gt;140572603639632</title>\n",
       "<path d=\"M518.016,-511.313C517.79,-503.289 517.515,-493.547 517.263,-484.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"520.76,-484.426 516.98,-474.529 513.763,-484.623 520.76,-484.426\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573144624848 -->\n",
       "<g class=\"node\" id=\"node10\"><title>140573144624848</title>\n",
       "<polygon fill=\"none\" points=\"140.5,-438.5 140.5,-474.5 268.5,-474.5 268.5,-438.5 140.5,-438.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"204.5\" y=\"-452.8\">home_feat: Reshape</text>\n",
       "</g>\n",
       "<!-- 140572603639824&#45;&gt;140573144624848 -->\n",
       "<g class=\"edge\" id=\"edge8\"><title>140572603639824-&gt;140573144624848</title>\n",
       "<path d=\"M197.679,-511.313C198.696,-503.289 199.931,-493.547 201.069,-484.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"204.556,-484.89 202.341,-474.529 197.612,-484.009 204.556,-484.89\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573144209040 -->\n",
       "<g class=\"node\" id=\"node11\"><title>140573144209040</title>\n",
       "<polygon fill=\"none\" points=\"296,-438.5 296,-474.5 427,-474.5 427,-438.5 296,-438.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-452.8\">visitor_feat: Reshape</text>\n",
       "</g>\n",
       "<!-- 140575373044240&#45;&gt;140573144209040 -->\n",
       "<g class=\"edge\" id=\"edge9\"><title>140575373044240-&gt;140573144209040</title>\n",
       "<path d=\"M361.5,-511.313C361.5,-503.289 361.5,-493.547 361.5,-484.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-484.529 361.5,-474.529 358,-484.529 365,-484.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603815568 -->\n",
       "<g class=\"node\" id=\"node12\"><title>140572603815568</title>\n",
       "<polygon fill=\"none\" points=\"296,-365.5 296,-401.5 427,-401.5 427,-365.5 296,-365.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-379.8\">all_feat: Concatenate</text>\n",
       "</g>\n",
       "<!-- 140572603639632&#45;&gt;140572603815568 -->\n",
       "<g class=\"edge\" id=\"edge10\"><title>140572603639632-&gt;140572603815568</title>\n",
       "<path d=\"M479.374,-438.494C457.963,-428.686 430.848,-416.266 408.024,-405.811\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"409.41,-402.596 398.861,-401.614 406.495,-408.96 409.41,-402.596\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573144624848&#45;&gt;140572603815568 -->\n",
       "<g class=\"edge\" id=\"edge11\"><title>140573144624848-&gt;140572603815568</title>\n",
       "<path d=\"M242.105,-438.494C263.792,-428.686 291.257,-416.266 314.376,-405.811\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"315.988,-408.923 323.657,-401.614 313.103,-402.545 315.988,-408.923\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573144209040&#45;&gt;140572603815568 -->\n",
       "<g class=\"edge\" id=\"edge12\"><title>140573144209040-&gt;140572603815568</title>\n",
       "<path d=\"M361.5,-438.313C361.5,-430.289 361.5,-420.547 361.5,-411.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-411.529 361.5,-401.529 358,-411.529 365,-411.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603817872 -->\n",
       "<g class=\"node\" id=\"node13\"><title>140572603817872</title>\n",
       "<polygon fill=\"none\" points=\"307.5,-292.5 307.5,-328.5 415.5,-328.5 415.5,-292.5 307.5,-292.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-306.8\">hidden_1: Dense</text>\n",
       "</g>\n",
       "<!-- 140572603815568&#45;&gt;140572603817872 -->\n",
       "<g class=\"edge\" id=\"edge13\"><title>140572603815568-&gt;140572603817872</title>\n",
       "<path d=\"M361.5,-365.313C361.5,-357.289 361.5,-347.547 361.5,-338.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-338.529 361.5,-328.529 358,-338.529 365,-338.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140572603817936 -->\n",
       "<g class=\"node\" id=\"node14\"><title>140572603817936</title>\n",
       "<polygon fill=\"none\" points=\"299,-219.5 299,-255.5 424,-255.5 424,-219.5 299,-219.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-233.8\">dropout_1: Dropout</text>\n",
       "</g>\n",
       "<!-- 140572603817872&#45;&gt;140572603817936 -->\n",
       "<g class=\"edge\" id=\"edge14\"><title>140572603817872-&gt;140572603817936</title>\n",
       "<path d=\"M361.5,-292.313C361.5,-284.289 361.5,-274.547 361.5,-265.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-265.529 361.5,-255.529 358,-265.529 365,-265.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573143359568 -->\n",
       "<g class=\"node\" id=\"node15\"><title>140573143359568</title>\n",
       "<polygon fill=\"none\" points=\"307.5,-146.5 307.5,-182.5 415.5,-182.5 415.5,-146.5 307.5,-146.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-160.8\">hidden_2: Dense</text>\n",
       "</g>\n",
       "<!-- 140572603817936&#45;&gt;140573143359568 -->\n",
       "<g class=\"edge\" id=\"edge15\"><title>140572603817936-&gt;140573143359568</title>\n",
       "<path d=\"M361.5,-219.313C361.5,-211.289 361.5,-201.547 361.5,-192.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-192.529 361.5,-182.529 358,-192.529 365,-192.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573143359888 -->\n",
       "<g class=\"node\" id=\"node16\"><title>140573143359888</title>\n",
       "<polygon fill=\"none\" points=\"299,-73.5 299,-109.5 424,-109.5 424,-73.5 299,-73.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-87.8\">dropout_2: Dropout</text>\n",
       "</g>\n",
       "<!-- 140573143359568&#45;&gt;140573143359888 -->\n",
       "<g class=\"edge\" id=\"edge16\"><title>140573143359568-&gt;140573143359888</title>\n",
       "<path d=\"M361.5,-146.313C361.5,-138.289 361.5,-128.547 361.5,-119.569\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-119.529 361.5,-109.529 358,-119.529 365,-119.529\" stroke=\"black\"/>\n",
       "</g>\n",
       "<!-- 140573143295952 -->\n",
       "<g class=\"node\" id=\"node17\"><title>140573143295952</title>\n",
       "<polygon fill=\"none\" points=\"318.5,-0.5 318.5,-36.5 404.5,-36.5 404.5,-0.5 318.5,-0.5\" stroke=\"black\"/>\n",
       "<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"361.5\" y=\"-14.8\">score: Dense</text>\n",
       "</g>\n",
       "<!-- 140573143359888&#45;&gt;140573143295952 -->\n",
       "<g class=\"edge\" id=\"edge17\"><title>140573143359888-&gt;140573143295952</title>\n",
       "<path d=\"M361.5,-73.3129C361.5,-65.2895 361.5,-55.5475 361.5,-46.5691\" fill=\"none\" stroke=\"black\"/>\n",
       "<polygon fill=\"black\" points=\"365,-46.5288 361.5,-36.5288 358,-46.5289 365,-46.5288\" stroke=\"black\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import SVG\n",
    "from keras.utils.vis_utils import model_to_dot, plot_model\n",
    "\n",
    "plot_model(model, to_file=\"model.png\")\n",
    "SVG(model_to_dot(model).create(prog='dot', format='svg'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "关公战秦琼："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "names1 = \"Mike Miller, LeBron James, Chris Bosh, Dwyane Wade, Mario Chalmers,\" \\\n",
    "         \" Ray Allen, Shane Battier, Chris Andersen, Udonis Haslem\"\n",
    "pid1 = [[pid2index[player2id[name]] for name in names1.split(\", \")]]\n",
    "min1 = [[19,47,28,39,40,20,29,19,2]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "names2 = \"LeBron James, JR Smith, Kevin Love, Kyrie Irving, Tristan Thompson, Richard Jefferson, Mo Williams, Iman Shumpert\"\n",
    "pid2 = [[pid2index[player2id[name]] for name in names2.split(\", \")]]\n",
    "min2 = [[47,39,30,43,32,26,5,19]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "pid1 = pad_sequences(pid1, maxlen=13, padding=\"post\")\n",
    "pid2 = pad_sequences(pid2, maxlen=13, padding=\"post\")\n",
    "min1 = pad_sequences(min1, maxlen=13, padding=\"post\")\n",
    "min2 = pad_sequences(min2, maxlen=13, padding=\"post\")\n",
    "\n",
    "min1 = 5 * min1.astype(np.float32) / min1.sum()\n",
    "min2 = 5 * min2.astype(np.float32) / min2.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tid2index[team2id[\"MIA\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "14"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tid2index[team2id[\"CLE\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 102.45483398,  100.24012756]], dtype=float32)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict([np.array([[14, 4]]), pid2, min2[..., None], pid1, min1[..., None]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 110.89669037,  106.30657196]], dtype=float32)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict([np.array([[4, 14]]), pid1, min1[..., None], pid2, min2[..., None]])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
